package com.blog.spark.job

import com.blog.spark.dao.StatDao
import com.blog.spark.entity.{BrowserOsTimes, DayAreaTimes, DayIpAreaTimes, DayOsTimes}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions._

import scala.collection.mutable.ListBuffer

/**
  * Final per-day statistics job, intended to run on YARN.
  *
  * Reads cleaned access logs (parquet) from `inputPath`, and for each day in
  * `[startDay, endDay]` computes four statistics — IP/area, area, OS platform,
  * and browser access counts — and batch-inserts the results into MySQL via
  * [[StatDao]]. Existing rows for a day are deleted first, so the job is
  * safely re-runnable.
  *
  * @author yuyon26@126.com
  * @date 2018/10/5 20:28
  */
object StatJobYARN {

  def main(args: Array[String]): Unit = {

    if (args.length != 3) {
      // Fixed: the usage line previously named the wrong class (SparkStatCleanJobYARN).
      println("Usage: StatJobYARN <inputPath> <startDay> <endDay>")
      System.exit(1)
    }

    val Array(inputPath, startDay, endDay) = args
    // Disable partition-column type inference so the "day" partition column
    // stays a string instead of being inferred as an integer type.
    val sparkSession = SparkSession.builder()
      .config("spark.sql.sources.partitionColumnTypeInference.enabled", "false")
      .getOrCreate()

    val accessDF = sparkSession.read.format("parquet").load(inputPath).coalesce(1)

    // Days are encoded as numeric strings (e.g. 20181005), so a Long range works.
    val days = startDay.toLong to endDay.toLong
    days.foreach { day =>
      // Delete any previously computed stats for this day before re-inserting.
      StatDao.deleteDate(day.toString)
      import sparkSession.implicits._
      val commonDF = accessDF.drop("traffic").drop("time")
        .filter($"day".equalTo(day))
        .distinct()
      // Cached because all four stat computations below reuse it.
      commonDF.cache()
      try {
        // IP + area statistics
        ipAreaAccessTopNStat(sparkSession, commonDF)
        // Area statistics
        areaAccessTopNStat(sparkSession, commonDF)
        // OS platform statistics
        platformAccessTonNStat(sparkSession, commonDF)
        // Browser statistics
        browserAccessTonNStat(sparkSession, commonDF)
      } finally {
        // Guarantee the cache is released even if a stat computation throws.
        commonDF.unpersist(true)
      }
    }
    sparkSession.stop()
  }

  /**
    * Counts accesses per (day, ip, area), ordered by count descending,
    * and batch-inserts the result into MySQL.
    *
    * @param sparkSession active session, needed for the `$` column implicits
    * @param commonDF     deduplicated single-day access rows (day/ip/area/os/browser)
    */
  def ipAreaAccessTopNStat(sparkSession: SparkSession, commonDF: DataFrame): Unit = {

    import sparkSession.implicits._
    val ipAreaTopNDF = commonDF.drop("browser").drop("os")
      .groupBy("day", "ip", "area").agg(count("ip").as("times"))
      .orderBy($"times".desc)
    // NOTE: an equivalent Spark SQL (createOrReplaceTempView + select/group by)
    // formulation was used historically; the DataFrame API above replaces it.

    // Write the statistics to MySQL, one batch insert per partition.
    try {
      ipAreaTopNDF.foreachPartition { partition =>
        val batch = new ListBuffer[DayIpAreaTimes]
        partition.foreach { info =>
          batch.append(DayIpAreaTimes(
            info.getAs[String]("day"),
            info.getAs[String]("ip"),
            info.getAs[String]("area"),
            info.getAs[Long]("times")))
        }
        StatDao.insertIPAreaBatch(batch)
      }
    } catch {
      // Best-effort: a failed write for one stat must not abort the other stats.
      case e: Exception => e.printStackTrace()
    }
  }

  /**
    * Counts distinct /24 subnets per (day, area), ordered by count descending,
    * and batch-inserts the result into MySQL. Truncating the IP to its first
    * three octets de-duplicates visitors from the same subnet.
    *
    * @param sparkSession active session, needed for the `$` column implicits
    * @param commonDF     deduplicated single-day access rows (day/ip/area/os/browser)
    */
  def areaAccessTopNStat(sparkSession: SparkSession, commonDF: DataFrame): Unit = {

    import sparkSession.implicits._
    val areaTopNDF = commonDF.drop("browser").drop("os")
      .selectExpr("day", "area", "SUBSTRING_INDEX(ip, '.', 3) as ip_")
      .distinct()
      .groupBy("day", "area").agg(count("area").as("times"))
      .orderBy($"times".desc)

    // Write the statistics to MySQL, one batch insert per partition.
    try {
      areaTopNDF.foreachPartition { partition =>
        val batch = new ListBuffer[DayAreaTimes]
        partition.foreach { info =>
          batch.append(DayAreaTimes(
            info.getAs[String]("day"),
            info.getAs[String]("area"),
            info.getAs[Long]("times")))
        }
        StatDao.insertAreaBatch(batch)
      }
    } catch {
      // Best-effort: a failed write for one stat must not abort the other stats.
      case e: Exception => e.printStackTrace()
    }
  }

  /**
    * Counts distinct /24 subnets per (day, os), ordered by count descending,
    * and batch-inserts the result into MySQL.
    *
    * NOTE: "TonN" in the method name is a historical typo for "TopN"; kept
    * unchanged so existing callers are not broken.
    *
    * @param sparkSession active session, needed for the `$` column implicits
    * @param commonDF     deduplicated single-day access rows (day/ip/area/os/browser)
    */
  def platformAccessTonNStat(sparkSession: SparkSession, commonDF: DataFrame): Unit = {
    import sparkSession.implicits._
    val osTopNDF = commonDF.drop("area").drop("browser")
      .selectExpr("day", "os", "SUBSTRING_INDEX(ip, '.', 3) as ip_")
      .distinct()
      .groupBy("day", "os").agg(count("os").as("times"))
      .orderBy($"times".desc)

    // Write the statistics to MySQL, one batch insert per partition.
    try {
      osTopNDF.foreachPartition { partition =>
        val batch = new ListBuffer[DayOsTimes]
        partition.foreach { info =>
          batch.append(DayOsTimes(
            info.getAs[String]("day"),
            info.getAs[String]("os"),
            info.getAs[Long]("times")))
        }
        StatDao.insertOsBatch(batch)
      }
    } catch {
      // Best-effort: a failed write for one stat must not abort the other stats.
      case e: Exception => e.printStackTrace()
    }
  }

  /**
    * Counts distinct /24 subnets per (day, browser), ordered by count
    * descending, and batch-inserts the result into MySQL.
    *
    * NOTE: "TonN" in the method name is a historical typo for "TopN"; kept
    * unchanged so existing callers are not broken.
    *
    * @param sparkSession active session, needed for the `$` column implicits
    * @param commonDF     deduplicated single-day access rows (day/ip/area/os/browser)
    */
  def browserAccessTonNStat(sparkSession: SparkSession, commonDF: DataFrame): Unit = {
    import sparkSession.implicits._
    // "time", "area" and "traffic" are dropped defensively; upstream already
    // removed "traffic"/"time" in main, so only "area"/"os" matter here.
    val browserTopNDF = commonDF.drop("time").drop("area").drop("traffic").drop("os")
      .selectExpr("day", "browser", "SUBSTRING_INDEX(ip, '.', 3) as ip_")
      .distinct()
      .groupBy("day", "browser").agg(count("browser").as("times"))
      .orderBy($"times".desc)

    // Write the statistics to MySQL, one batch insert per partition.
    try {
      browserTopNDF.foreachPartition { partition =>
        val batch = new ListBuffer[BrowserOsTimes]
        partition.foreach { info =>
          batch.append(BrowserOsTimes(
            info.getAs[String]("day"),
            info.getAs[String]("browser"),
            info.getAs[Long]("times")))
        }
        StatDao.insertBrowserBatch(batch)
      }
    } catch {
      // Best-effort: a failed write for one stat must not abort the other stats.
      case e: Exception => e.printStackTrace()
    }
  }

}

